{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "rX8mhOLljYeM"
   },
   "source": [
    "## TensorFlow使用底层API实现模型训练"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1. 导入包"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "0trJmd6DjqBZ"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "from tensorflow.keras.layers import Dense, Flatten\n",
    "from tensorflow.keras import Model\n",
    "\n",
    "# Make float64 the default Keras float type; presumably this matches the\n",
    "# float64 arrays produced when the images are divided by 255.0 later on.\n",
    "tf.keras.backend.set_floatx('float64')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 下载手写数字数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "JqFRS6K07jJs"
   },
   "outputs": [],
   "source": [
    "mnist = tf.keras.datasets.mnist\n",
    "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
    "\n",
    "# Divide by 255.0 (presumably to bring 8-bit pixel values into [0, 1]).\n",
    "x_train = x_train / 255.0\n",
    "x_test = x_test / 255.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(60000, 28, 28)"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the shape of the training image array.\n",
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(10000, 28, 28)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the shape of the test image array.\n",
    "x_test.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 生成Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8Iu_quO024c2"
   },
   "outputs": [],
   "source": [
    "# Batching / shuffling settings, named so they can be tuned in one place\n",
    "# and so the train and test pipelines cannot drift apart.\n",
    "BATCH_SIZE = 32\n",
    "SHUFFLE_BUFFER_SIZE = 10000\n",
    "\n",
    "# Training pipeline: slice into (image, label) pairs, shuffle, then batch.\n",
    "train_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
    "    .shuffle(SHUFFLE_BUFFER_SIZE)\n",
    "    .batch(BATCH_SIZE)\n",
    ")\n",
    "\n",
    "# Evaluation pipeline: same batching, but no shuffling.\n",
    "test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. 使用子类化（Subclassing）API搭建模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "h3IKyzTCDNGo"
   },
   "outputs": [],
   "source": [
    "class MyModel(Model):\n",
    "    \"\"\"Fully connected classifier: Flatten -> Dense(128, relu) -> Dense(10, softmax).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.flatten = Flatten()\n",
    "        self.d1 = Dense(128, activation='relu')\n",
    "        self.d2 = Dense(10, activation='softmax')\n",
    "\n",
    "    def call(self, x):\n",
    "        \"\"\"Run the forward pass and return the output of the final softmax layer.\"\"\"\n",
    "        flat = self.flatten(x)\n",
    "        hidden = self.d1(flat)\n",
    "        return self.d2(hidden)\n",
    "\n",
    "\n",
    "model = MyModel()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uGih-c2LgbJu"
   },
   "source": [
    "### 5. 选择优化器与损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "u48C9WQ774n4"
   },
   "outputs": [],
   "source": [
    "# Loss function: sparse categorical cross-entropy with default arguments.\n",
    "# The model's last layer already applies softmax, so the loss is given\n",
    "# probabilities rather than raw logits.\n",
    "loss_object = tf.keras.losses.SparseCategoricalCrossentropy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optimizer used to update the model parameters (Adam with default settings).\n",
    "optimizer = tf.keras.optimizers.Adam()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ix4mEL65on-w"
   },
   "source": [
    "### 6. 使用`tf.GradientTape`训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training metrics: running mean of the loss and running accuracy.\n",
    "train_loss = tf.keras.metrics.Mean(name='train_loss')\n",
    "train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OZACiVqA8KQV"
   },
   "outputs": [],
   "source": [
    "# TensorFlow applies AutoGraph to functions decorated with @tf.function,\n",
    "# which speeds up execution.\n",
    "\n",
    "@tf.function\n",
    "def train_step(images, labels):\n",
    "    \"\"\"Run one optimization step on a batch and update the training metrics.\n",
    "\n",
    "    Args:\n",
    "        images: Batch of inputs passed to the model.\n",
    "        labels: Labels for the batch, passed to the loss and the accuracy metric.\n",
    "    \"\"\"\n",
    "    # Only the forward pass and the loss need to be recorded on the tape.\n",
    "    with tf.GradientTape() as tape:\n",
    "        # training=True only matters for layers that behave differently in\n",
    "        # training and inference (e.g. Dropout); it is passed for explicitness.\n",
    "        predictions = model(images, training=True)\n",
    "        loss = loss_object(labels, predictions)\n",
    "\n",
    "    # Differentiate and apply the update outside the tape context so the\n",
    "    # gradient computation and optimizer ops are not recorded by the tape.\n",
    "    gradients = tape.gradient(loss, model.trainable_variables)\n",
    "    optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n",
    "\n",
    "    train_loss.update_state(loss)\n",
    "    train_accuracy.update_state(labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "i-2pkctU_Ci7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, Loss: 0.26051258539160094, Accuracy: 92.72333333333333\n",
      "Epoch 2, Loss: 0.11415459908234576, Accuracy: 96.62333333333333\n",
      "Epoch 3, Loss: 0.07836290154916545, Accuracy: 97.595\n",
      "Epoch 4, Loss: 0.05948402446713299, Accuracy: 98.13333333333333\n",
      "Epoch 5, Loss: 0.04566015650037055, Accuracy: 98.585\n"
     ]
    }
   ],
   "source": [
    "EPOCHS = 5\n",
    "\n",
    "for epoch in range(EPOCHS):\n",
    "    # Start every epoch with fresh running metrics so the printed values\n",
    "    # describe this epoch only.\n",
    "    train_accuracy.reset_states()\n",
    "    train_loss.reset_states()\n",
    "\n",
    "    # One optimization step per batch.\n",
    "    for images, labels in train_ds:\n",
    "        train_step(images, labels)\n",
    "\n",
    "    print(\n",
    "        f'Epoch {epoch+1}, '\n",
    "        f'Loss: {train_loss.result()}, '\n",
    "        f'Accuracy: {train_accuracy.result()*100}'\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 7. 测试评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation metrics: running mean of the loss and running accuracy.\n",
    "test_loss = tf.keras.metrics.Mean(name='test_loss')\n",
    "test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reset the metrics first so that re-running this cell does not accumulate\n",
    "# results from a previous evaluation pass.\n",
    "test_loss.reset_states()\n",
    "test_accuracy.reset_states()\n",
    "\n",
    "for test_images, test_labels in test_ds:\n",
    "    # Call the model directly in inference mode instead of model.predict():\n",
    "    # predict() is meant for large inputs, not per-batch calls in a loop.\n",
    "    predictions = model(test_images, training=False)\n",
    "\n",
    "    t_loss = loss_object(test_labels, predictions)\n",
    "    test_loss.update_state(t_loss)\n",
    "    test_accuracy.update_state(test_labels, predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loss: 0.07311855612242536, Accuracy: 97.74000000000001\n"
     ]
    }
   ],
   "source": [
    "# Report the accumulated test loss and the test accuracy as a percentage.\n",
    "print(f'Loss: {test_loss.result()}, Accuracy: {test_accuracy.result()*100}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "advanced.ipynb",
   "private_outputs": true,
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
